{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "machine_shape": "hm",
      "gpuType": "A100"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# Vision Transformer for Image Classification on CIFAR-10"
      ],
      "metadata": {
        "id": "CTA99mzwHTlA"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "guLVvLCRz7Hw"
      },
      "outputs": [],
      "source": [
        "# Imports\n",
        "import math\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "import torchvision.transforms as transforms\n",
        "import torchvision.datasets as datasets\n",
        "import matplotlib.pyplot as plt\n",
        "from sklearn.metrics import accuracy_score"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
      ],
      "metadata": {
        "id": "gAhi19nrQK1O"
      },
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "device"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "v5PX4GNhRtQR",
        "outputId": "b7d9a17c-7b9f-4f38-bcc6-30fe273afb29"
      },
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "device(type='cuda')"
            ]
          },
          "metadata": {},
          "execution_count": 4
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Loading Dataset and Transforming it"
      ],
      "metadata": {
        "id": "nTiTCZCaHX1y"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def patching(image_tensor, patch_size=8):\n",
        "  patches = image_tensor.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)\n",
        "  patches = patches.contiguous().view(3, 16, 8, 8)\n",
        "  return patches\n",
        "\n",
        "my_transforms = transforms.Compose([transforms.ToTensor(),\n",
        "                                    transforms.Normalize(mean=[0.5, 0.5, 0.5],std=[0.5, 0.5, 0.5]),\n",
        "                                    ])#transforms.Lambda(lambda x: patching(x, patch_size=8))])"
      ],
      "metadata": {
        "id": "K9hcc3Ex0LNX"
      },
      "execution_count": 36,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Total 10 Labels in CIFAR-10 as shown below"
      ],
      "metadata": {
        "id": "Pa6tXfgbHgUA"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Labels\n",
        "labels = \"\"\"airplane\n",
        "automobile\n",
        "bird\n",
        "cat\n",
        "deer\n",
        "dog\n",
        "frog\n",
        "horse\n",
        "ship\n",
        "truck\"\"\".split()\n",
        "labels = {i: name for i, name in enumerate(labels)}"
      ],
      "metadata": {
        "id": "F9o7HTFW9BGs"
      },
      "execution_count": 37,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Original Image Size of 32x32 and 10 Classes Present in the dataset\n",
        "train = datasets.CIFAR10('./', download=True, train=True, transform=my_transforms)\n",
        "test = datasets.CIFAR10('./', download=True, train=False, transform=my_transforms)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "YgjdlLYU0C_h",
        "outputId": "8278d1c1-959b-4a30-f54b-25892bf2598b"
      },
      "execution_count": 38,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Files already downloaded and verified\n",
            "Files already downloaded and verified\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "train_dataloader = torch.utils.data.DataLoader(train, batch_size=128, shuffle=True)\n",
        "test_dataloader = torch.utils.data.DataLoader(test, batch_size=128, shuffle=True)"
      ],
      "metadata": {
        "id": "F4_BBOqg0mJU"
      },
      "execution_count": 39,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "images, label = next(iter(train_dataloader))"
      ],
      "metadata": {
        "id": "iqngvHEk0rKA"
      },
      "execution_count": 40,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "images.shape"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "r5zU1Hc80zb3",
        "outputId": "cd203c6f-543c-4aed-fa61-ddec3db414da"
      },
      "execution_count": 41,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(torch.Size([128, 3, 32, 32]),)"
            ]
          },
          "metadata": {},
          "execution_count": 41
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "[labels[i] for i in label.numpy()]"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "nfuaxFJm7Xkk",
        "outputId": "2db23b25-23ed-4b7f-f1ce-2141c5402036"
      },
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "['horse',\n",
              " 'dog',\n",
              " 'bird',\n",
              " 'truck',\n",
              " 'ship',\n",
              " 'cat',\n",
              " 'airplane',\n",
              " 'ship',\n",
              " 'ship',\n",
              " 'airplane',\n",
              " 'ship',\n",
              " 'dog',\n",
              " 'dog',\n",
              " 'frog',\n",
              " 'automobile',\n",
              " 'bird',\n",
              " 'bird',\n",
              " 'horse',\n",
              " 'deer',\n",
              " 'horse',\n",
              " 'horse',\n",
              " 'dog',\n",
              " 'horse',\n",
              " 'dog',\n",
              " 'airplane',\n",
              " 'dog',\n",
              " 'cat',\n",
              " 'ship',\n",
              " 'deer',\n",
              " 'dog',\n",
              " 'horse',\n",
              " 'ship',\n",
              " 'bird',\n",
              " 'ship',\n",
              " 'deer',\n",
              " 'airplane',\n",
              " 'frog',\n",
              " 'frog',\n",
              " 'truck',\n",
              " 'bird',\n",
              " 'ship',\n",
              " 'ship',\n",
              " 'dog',\n",
              " 'deer',\n",
              " 'automobile',\n",
              " 'truck',\n",
              " 'truck',\n",
              " 'dog',\n",
              " 'horse',\n",
              " 'truck',\n",
              " 'dog',\n",
              " 'bird',\n",
              " 'horse',\n",
              " 'truck',\n",
              " 'bird',\n",
              " 'dog',\n",
              " 'deer',\n",
              " 'dog',\n",
              " 'truck',\n",
              " 'deer',\n",
              " 'airplane',\n",
              " 'dog',\n",
              " 'frog',\n",
              " 'cat',\n",
              " 'dog',\n",
              " 'bird',\n",
              " 'airplane',\n",
              " 'truck',\n",
              " 'deer',\n",
              " 'automobile',\n",
              " 'truck',\n",
              " 'dog',\n",
              " 'ship',\n",
              " 'horse',\n",
              " 'cat',\n",
              " 'truck',\n",
              " 'frog',\n",
              " 'cat',\n",
              " 'horse',\n",
              " 'truck',\n",
              " 'horse',\n",
              " 'airplane',\n",
              " 'dog',\n",
              " 'dog',\n",
              " 'ship',\n",
              " 'dog',\n",
              " 'cat',\n",
              " 'horse',\n",
              " 'dog',\n",
              " 'deer',\n",
              " 'frog',\n",
              " 'cat',\n",
              " 'ship',\n",
              " 'truck',\n",
              " 'bird',\n",
              " 'ship',\n",
              " 'dog',\n",
              " 'dog',\n",
              " 'dog',\n",
              " 'cat',\n",
              " 'frog',\n",
              " 'dog',\n",
              " 'airplane',\n",
              " 'bird',\n",
              " 'frog',\n",
              " 'airplane',\n",
              " 'dog',\n",
              " 'horse',\n",
              " 'horse',\n",
              " 'frog',\n",
              " 'deer',\n",
              " 'dog',\n",
              " 'truck',\n",
              " 'truck',\n",
              " 'automobile',\n",
              " 'truck',\n",
              " 'dog',\n",
              " 'cat',\n",
              " 'horse',\n",
              " 'ship',\n",
              " 'deer',\n",
              " 'deer',\n",
              " 'ship',\n",
              " 'dog',\n",
              " 'ship',\n",
              " 'horse',\n",
              " 'bird',\n",
              " 'dog']"
            ]
          },
          "metadata": {},
          "execution_count": 11
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Plotting Image Patches"
      ],
      "metadata": {
        "id": "DoYYJ7fwHQvY"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def plot_patches(patches, label, ncols=4):\n",
        "    \"\"\"Plot image patches in a grid titled with the class label.\n",
        "\n",
        "    Accepts pre-patched input of shape (C, N, P, P) or a raw (C, H, W) image,\n",
        "    which is split into 8x8 patches here (the Compose above no longer patches,\n",
        "    so raw dataloader images would otherwise raise an IndexError below).\n",
        "    \"\"\"\n",
        "    if patches.dim() == 3:  # raw image from the dataloader -> patch it first\n",
        "        patches = patches.unfold(1, 8, 8).unfold(2, 8, 8).contiguous().view(patches.shape[0], -1, 8, 8)\n",
        "    n = patches.shape[1]\n",
        "    nrows = math.ceil(n / ncols)\n",
        "    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols, nrows))\n",
        "    for i, ax in enumerate(axes.flat):\n",
        "        if i < n:\n",
        "            # Undo Normalize(mean=0.5, std=0.5) so pixel values land in [0, 1]\n",
        "            # and imshow stops emitting clipping warnings.\n",
        "            ax.imshow((patches[:, i] * 0.5 + 0.5).permute(1, 2, 0))\n",
        "        ax.axis('off')\n",
        "    plt.subplots_adjust(wspace=0.05, hspace=0.05)\n",
        "    fig.suptitle(f\"Label: {label}\", fontsize=30)\n",
        "idx = 26\n",
        "plot_patches(images[idx], labels[label[idx].item()], 4)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 668
        },
        "id": "aB_LSfJ71S6U",
        "outputId": "21ca9aee-a398-470f-85ba-4fb4554b8a86"
      },
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 400x400 with 16 Axes>"
            ],
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUkAAAFvCAYAAADHfzmDAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAmQUlEQVR4nO3de3CU9dn/8SvLsiwxLEsMMYQ0CTGEU+BHBQGtCEXLUKvoo089PFht/Wnrrx4frViLo7bWnjzUqmMdaysyHbVWSlHHWrWCgEoBKWIEjAETjCFAWJawLmFZ9/79oYRsDt/rWggC8n7NZCYbrv1+7733zoc7yfeQ5XmeJwCATvkO9QEAwOGMkAQAB0ISABwISQBwICQBwIGQBAAHQhIAHAhJAHAgJAHAgZAEAAdC8gh1xx13SFZWVuvHwoULD/UhHZDZs2envZ7Zs2cf6kMCRISQBACnozYkS0tLv1R3YgAOjqM2JIGjxcKFC9NuCO64445DfUhHFEISABwISQBwICQBwIGQBAAHQhIAHPyH+gCOBlu2bJGqqipZv369RKNRSSaTkpubKwUFBTJ+/HgpKCjo9j537twpixcvlg8++EA++eQT6d+/v5SVlcmpp54qPXv27JY+Nm7cKCtWrJDNmzfL9u3bpW/fvlJQUCBf+9rXDsprOtwkEgn597//LXV1dbJ161aJx+PSp08fKSkpkcrKSjn++OMzbjMajUpVVZW8//77sn37dkkkEhIOhyU/P19OPPFEKSkpOQivBE7eUaqkpMQTkdaPBQsWdFvbqVTKW7RokXf11Vd7Q4YMSeuns4+RI0d6s2fP9vbs2WPu4/bbb+/0+Lds2eJdccUVXq9evTrtKzc317vlllu8Xbt27ddr2717t3f//fd7w4cP7/L1ZGVleWPHjvXmz59vbvfxxx9Pa+Pxxx/fr+P7Irz55pveWWed5WVnZzvf19LSUu9HP/qRV1NT42xv5cqV3syZM73Ro0d7WVlZzjbLysq83/3ud148HlePU7vuuvrozu+FLwNC8iBcGDfeeON+XZynnHKKt3nzZlMfnYXk6tWrvQEDBpj6qqio8D788MOMXtfSpUu9QYMGZfSazjrrLC8Wi6lt709Itn8PD3awNjc3e+edd17G7+ukSZO6bPPBBx/cr2tl+PDh3gcffOA8XkKye/Dj9kHQ0tLS4Wv9+vWTgoICCYVCsnv3btm8ebNs2rQprWbJkiUyZcoUWb58ufTu3TujPjdv3iwzZsxIa7OoqEgKCgokEolIbW2tpFKp1n+rrq6WKVOmyJIlS6SwsFBt//nnn5cLLrhAdu3alfb1QCAggwYNkr59+0pzc7PU1NRIMplMe96UKVPk9ddfl2AwmNFrOpzU19fLtGnT5L333uvwb3369JGBAwdKKBSSHTt2SF1dXafXQGc6q+vTp48UFhZK3759JZlMytatW+Wjjz5Kq1mzZo1MnDhR3nnnHcnPz9+/FwWbQ53Sh8rBvJO86qqrvD59+njf+973vGeffdZraGjotK6+vt771a9+5YXD4bRjufbaa9U+2t9Jtn09F154obd27dq0+o8//ti76aabvB49eqQ9b9q0aWpfVVVVXu/evdOeN3HiRO+FF17o8GN7c3Oz9+ijj3rHHXdcWv2VV17p7ONwvpPcvXu3N27cuA53XOecc463ePFiL5lMptXv2bPHW758uTdr1iyvuLjYeSd59913e7169fLOP/98b86cOV5dXV2ndU1NTd7vf/97r7CwMO0Ypk+f3mXbr7zyivfKK69499xzT9pzvvOd77T+W2cfkUhkv87TlxUheRBCcsWKFd6OHTvM9bW1tWk/xvbu3dvbtm2b8zntQ3Lvx89//nPn8+bOndshKJ955pku6/fs2eNVVlam1f/0pz/1UqmUs5/6+npv8ODBac9buXJll/WHc0j++Mc/TusnEAh4Tz75pOm5iUTCW7RoUZf/Xl
VVZf4Vi+d5XiQS8caMGZN2PO+9957zOQsWLEirv/322839wfMYAnQQjBkzRkKhkLm+pKRE/vCHP7Q+3rVrlzz99NMZ93vmmWfKrFmznDXnnnuu3HTTTWlf++1vf9tl/bPPPitVVVWtj3/wgx/IbbfdJllZWc5+Bg4cKHPnzhWfb98ldu+99zqfcziKRCLy4IMPpn3toYcekosuusj0/J49e8rEiRO7/PcRI0Zk9ONyv3795Kmnnko7rywrd3ARkoeJ0047TQYMGND6+M0338y4jV//+temulmzZqWF+FtvvSVr167ttPb+++9v/Tw7O1t++ctfmo9n5MiRcvbZZ7c+nj9/vnz66afm52tqa2vF++ynIfE8T7773e92W9t7PfbYY/LJJ5+0Pp44caJcccUV3d5PJgYPHizjxo1rfbw/1wrsCMnDSGlpaevn//nPfzJ67pgxY2T48OGm2pycHDnvvPPSvvb66693qNu2bZssW7as9fGZZ54p/fr1y+i4pk6d2vp5LBbL+HUdai+//HLa42uvvfYQHUm6QYMGtX5+pJ3TIw0heZDV1tbKb37zG/n2t78tw4YNk/79+0uvXr3Slq7a+/HWW2+1Pq+pqSmjfiZPnnxA9W3DcK8lS5aI53mtj8eOHZtRHyIixcXFaY+7umM9HCWTSVm6dGnrY5/PJ9OmTTto/W3evFkefPBBmTFjhowcOVKOO+446d27d6fXylNPPdX6vHg83mHUAboPQ4AOkrq6OrnuuuvkueeeSwsaq2g0mlF9ZWXlAdV/+OGHHWraB9rMmTNl5syZGfXTXiQSOaDnf5EaGxvTftQeMmSI5OTkdHs/TU1NMnPmTJkzZ85+/zoiGo1mPGwMNoTkQbBs2TKZOnWq7NixY7/bSCQSGdUfe+yxB1TfWShv27YtozYtDuScfNHaB/rBGI+4fv16mTx5stTX1x9QO7t37+6mI0J7hGQ327Ztm5xxxhkdwmDUqFEyceJEKS8vl8LCQundu7cEg8G0vxLfeOONsnr16v3qNzs7O6P6Y445Ju1xLBbrUJPp3axF2wHth7udO3emPe7uu8hEIiFnnHFGh4AcPHiwTJo0SYYMGSIDBw6UY445pvXH7r3uvvvuDr8vxcFBSHazu+66K+0ObPDgwfLnP/857a+RXck06NqKx+MZ1bf9MVKk8wBofzzXX3+9fOtb38r84NooKys7oOd/kfr06ZP2uLP/SA7EI488ItXV1a2PjzvuOJk9e7bp955//OMfu/VY0DVCspv95S9/af08GAzKSy+9ZA6GA/l9XaZ/6Gn/o3Q4HO5Qk5eXl/Z4wIABcvrpp2d8bEeq3NzctMdbtmzp1vbbj4WdN2+enHTSSabnHkm/2z3S8dftbrRx40ZpaGhofTxt2jRzQO7atavTP55YtR3wbfHuu++mPW47pKSrr9XU1GR+YEewgoKCtDvs999/v9vuJlOplCxfvrz18ejRo80BKSKdziHHwUFIdqPNmzenPR4yZIj5uYsXL5Y9e/bsd9+djXPMpL6zXwd8/etfT3v82muvZX5gRzC/358WXKlUSl566aVuaXvbtm1pC4Fkcq1UV1fLxx9/bK5vOztHRPZrtMXRjJDsRu0vvkz+Qv3www8fUN9vv/22rFmzxlQbi8Vk7ty5aV+bNGlSh7qBAwemDRVav369/OMf/zig4zzStP/94AMPPNAt7X6R10r7P9Jl+vvrox0h2Y3ar8a9ZMkS0/NefPFFmT9//gH3f/PNN5vq7rrrLmlubm59PGHCBBk2bFinte3neV9//fVH1DCeA3XZZZel/QFn8eLFafPs99exxx4rfv++PwksXbo07c6yK6tWrco4JNv/bvVAfq1zNCIku1FxcbEMHDiw9fHy5cvT/pDTmWXLlsnFF1/cLf2/8MIL8otf/MJZM2/ePLn77rvTvnb99dd3WT9jxgwZMWJE6+Pq6mr55je/mfa7V82ePXvkiSeeMM8ttyotLU2bhXIwFn
oIh8Ny3XXXpX3t6quvNi9AsmfPHlm8eHGHr/fo0UPGjx/f+njTpk3qAiA1NTVy9tlnZ/xrmZKSkrTfrf7rX/+S7du3Z9TG0Yy/bn/u7bffNv1P3pm2f/G95JJL0haBuOSSS2TDhg1y1VVXpS0qUV9fL4888ojcc889snv3bgkGg1JQUCC1tbX7dQwlJSVSV1cns2bNkqqqKrn99tvTfs+1adMmuf/+++Xee+9Nm9UxdepUueCCC7pst0ePHjJ37lwZP3586x3kW2+9JZWVlXLNNdfIjBkzpKKiosPzNm/eLMuXL5fnn39e5s2bJ1u3bpVLL710v17boXbbbbfJq6++2jpFMZFIyEUXXSR//etf5YYbbpAJEyZIjx49WuuTyaS88847Mm/ePJkzZ46UlZXJwoULO7R7ySWXyBtvvNH6+JZbbpGtW7fKzJkz0wauNzU1yRNPPCF33nmn7NixQ7KysqSiokLef/990/H7fD457bTTWn9aiUajMmHCBLnssstk6NChHX4cHzNmTMZz9L/UDt0qbYdW+7UID+SjrW3btnlFRUUdavx+vzdixAhv3Lhx3qBBgzrsZfLoo496kyZN6rLd9tqvJ/n00093WJC1uLjYO/HEE73y8nLP5/N1OKaSkhLvo48+Mp2v1157zevXr1+nrz8vL8+rrKz0xo8f7w0bNszr379/p3WXXnppl+0fzutJet5n62OOGDGi09fVp08fb/jw4d748eO9oUOHesFgMO3fu1p0N5FIeKNHj+7Qns/n84YMGeKNHz/eKy8v77D+509+8hPv0ksvTfuathXHwoUL1f1z9n6wfUM6QrKbQ9LzPtvYqf3K3F19+Hw+77777vM8zzugkFywYIH37rvvdgjKrj7Ky8u99evXZ3TOampqvBNPPHG/zlFWVpZ36623dtn24R6Snud50WjUmz59esav3bUyeV1dXYfFiV0fN9xwg5dKpTIOSc/zvAceeMDr2bMnIZkhfid5EHz1q1+Vt99+Wy6++OK0H8PaysrKkm984xuydOlS+d///d9u6beyslJWrVoll19+ufTq1avTmn79+snNN98sq1evznj2y/HHHy/Lli2T5557TqZMmSKBQMBZ36NHDznppJPkZz/7mdTU1Midd96ZUX+Hm759+8r8+fNlwYIFMnXqVPX1DxkyRG699VaZM2dOlzXFxcWyfPlyufrqq517AE2YMEH++c9/yr333qsueNyVa665RtauXSu33XabTJkyRQoLCyU7O3u/2ztaZHkeg6YOpkgkIosWLZK6ujrZuXOnHHPMMTJo0CA5+eSTD+oGTs3NzWn7bufl5UlZWZlMnjy52/bdjsfjsnTpUvnoo49k27ZtsmvXLsnJyZG8vDwZMmSIDBs2rMPvu75MPvnkE3njjTekvr5empqa5NNPP5VQKCSDBg2SUaNGSVFRUUbtxWIxWbx4sdTU1MiOHTukd+/e8pWvfEUmTJjQYck5fHEISQBw4MdtAHAgJAHAgZAEAAfzYHJfxRl6Y35b5radjnUgbVn7CwT0/rS/VGZi42v7/ppZOlWfTeMzvguWY/QbGvMZOwxK139t3cc2AN9nWGx36d8eSns84Zwr9d6N/82nDsPbAevyw5ZJDlV/f6z188pzLlfr2y960ZVEQj9Ky0LK1sWWLa81mbS1lTBcmluWPKnWHIaXDgAcPghJAHAgJAHAgZAEAAdCEgAcCEkAcCAkAcDBPE4yJ1sfM2dN3EBQH+9nGcblD9jGS1mGBfrEtseIP8PxlMHgFzve0zb+zTpCr8VYpzMObUuXYxhPa34tOtN4P+vCzIb3wXrkvgxfo99wiaZStteRMnxfpCzHZx3Pavk+NLZlHQ/cTd0BwNGJkAQAB0ISABwISQBwICQBwIGQBAAHQhIAHAhJAHAgJAHAwTwkPZGIqDXWxE2JPiUgYFmZ3Hj4PsPy1daZNIEMl7j2JfRZK/6U8W1IGGYjWFZ99xlncPj0WRk+y/QOEU
nsx5SbRFy/5iwrsYvYLnTLbKXunHFjnRHSYnnf20i16Nec9Sr2JQ1979d0qv3vL2VZclxEksY6DXeSAOBASAKAAyEJAA6EJAA4EJIA4EBIAoADIQkADoQkADiYB5N//5LpemOmrQNEsoOWrSAsbdn6sxxXOBwytWXaV6KN7194rlqTnW0bkO03DBS3bArgS9kGJ0eqFqo1AWk2tZU/+nRTXVuXn69fc9adDYIB/Zqz7KJh2qpARFKG6ySRsLXltwzobuOH/32GWhM3TorwGSYemMaSG7bGEBFJxuNqTcw6mLybtvbgThIAHAhJAHAgJAHAgZAEAAdCEgAcCEkAcCAkAcCBkAQAB/Ng8vPPnKLWWFcMDlpWzzasFO4PGFcmN60SbRxcm+H/K9OnnaLWWAfh+/z64FjLOUk0rjP19+yzj6k1gcQeU1ujxhaZ6tq6eOrJak2LL8fUVrZppLhh4LTxO8Y0kNk6KDqVbev0c+dOO1XvOlJlaqt+w2q1JlSu9+fLLjD154vrkxMSxjchYU43N+4kAcCBkAQAB0ISABwISQBwICQBwIGQBAAHQhIAHAhJAHAgJAHAwTwmvap6o1qTStpmEOQYltIPZRtqcrpv24OgYUsJa1ttNUf15eit/1f5fPr59RteR6Rqham/DSv12TSFxaam5LWHf63WDD3zV2mPq19+WO9/8pWm/qMJ/RynDO9Dyvj+WzYOsF5JLdKi1hTIvu1Hqle9qNbXPPtjU99VS7erNdNuu1utyZ3wP6b+An59GxWfWL6nRJLGGU16fwCALhGSAOBASAKAAyEJAA6EJAA4EJIA4EBIAoADIQkADvaR0Yal7QPGbQgClq0ZDG2lEglTfxbJFn3AroiIz3DsbZkG2BvPW8rwHkhSPyfZIdtWCkPHjlBr4omoqa0tDR+b6tp67W/6APTpeba2LK854TNskxCwbRcRCIbVmlSLbVB045In1ZqKHz7T+vnqP+kD7GtrPVPfzYZvsZThEm+J215rwnL9+ixD9UUCwcy+V7vCnSQAOBCSAOBASAKAAyEJAA6EJAA4EJIA4EBIAoADIQkADoQkADiYZ9wU5BlmGhi3b8gO6N36DPkdMC6lHzD05zdsKfFZW5mN4vcHLP8P2c6bz6/PNMjJ0V9HqtA2TSWcr7/WhiW2mTTVG0xlaTa8ptfU+/VZOSIio6f0UmvC+fqsnFTKMCtHRJLN+nu6aulaU1uLVuo1p/5w3+eRJn02TW64p6nvUHmlWpNTrNdIPGrqL9HSrNY0G2PL78tsq5WucCcJAA6EJAA4EJIA4EBIAoADIQkADoQkADgQkgDgQEgCgEOW53m2ddwB4CjEnSQAOBCSAOBASAKAg3kG+N9felWtSRi3jczJ1RfLCOfoh1YYsi02EQzoC0PEU7a2YqmQWjOqvKL185VrqtX6kN+2nW1enn5OQtl6W2teeMjU3+zrH1drNkZNTYk/rNc8WZ/+6/GflGepz6kwtCsiMu4EvSa/UK+J6+sviIhI/Ua9ZsU6W1uro3rNY23O3eWj9fM2buwwU99lZ1ym1uSfcI5ak4zbrvGWlohaE2sxbsFsWHBn2qmT1RruJAHAgZAEAAdCEgAcCEkAcCAkAcCBkAQAB0ISABzM4ySH5xSrNVXRWlNbkdpVak1RaUytyTOMqRIR8fn08VJBn2GjMxGJbTEUlf+q9dP6lX9Xy/ML8k19FwX1Tb4SNUvUmoYls0391RpOb4txX7QK/fLpoOwUQ43t1Ekiqtc0GcYtNjbZ+osZvrNabHvPSdi2b9u++rBekwokTG21GDapW7mySm/IZ7sfCxiOq6VFH/csIpJK2eo03EkCgAMhCQAOhCQAOBCSAOBASAKAAyEJAA6EJAA4EJIA4EBIAoCDecZNRbm+Ind5QYGprY0L9ZWxI397Xq1p1A9JRERCuXqNdcXpyBpD0fR9M25iT9+slhedcZqp78aI/n9askmf/ZCdm23qr3ToJ2pNoXHGyykT+t
sK2zhz6v9Ra1rililQIs1rNqk1CcMMI33+yWcsE7iGGmfS+G0Ty/YxTDSJNa83NRXbol/wGyL6quOGEhERSfn03Q18xtjy+bvnHpA7SQBwICQBwIGQBAAHQhIAHAhJAHAgJAHAgZAEAAdCEgAczIPJE4362vYNC2eb2lrztD5QvHql3s7wyabupLBCr9lSY2srstpWt1d0kV7jP9mw/L2IBHIK1ZqckL4nQGNjo6m/kw3nbdyUIaa28gttEw3ayk7peyVUr9AHiYuINNXqNVHDabEOJs8zbGth3LVDQhmeuqGGrTKsr2NogV4ZTeoDwCPNth4D2fr160/aYise14/LgjtJAHAgJAHAgZAEAAdCEgAcCEkAcCAkAcCBkAQAB0ISABzMg8kDLXppU71tkPIaQ1lcH1MqQePK5Ja6WMLWVpVh7PJ/t/k8slOvr67ebOo7J09/If7ERrUmtsEz9Rc0jP/dUPO+qa0tMX257BPaPV63QR8MvGKFqXtpqNdrUoYVvYOGQeIiIlFDWyHDoG8Rkcqhg2yFn5t6aolaU7WmztRWdos+y2JCZaVasyUWNfUXa9FPXND6JhjasuBOEgAcCEkAcCAkAcCBkAQAB0ISABwISQBwICQBwIGQBAAHQhIAHMwzbmIblqo1J0z9vqmtYJG+DUH9qsfVGr9hVo6IyIYGvSbfOPshu7etbi/DRA/JrrW1NWGyPgNlzaJP1Zp1tt0ipMUwCylYa2trdc0Has2fbkh/PP727bbGDb5mqJlmmNgSNe4I8NC7ek3cUCMicu26D9Wa2y7f9/nKJfpsmpRx/4bGquVqTVn5NLWmuFj/nhcRqd6gb9mR8Nvu7SxbQVhwJwkADoQkADgQkgDgQEgCgAMhCQAOhCQAOBCSAOBASAKAg3kw+Q0zblJrzjz7JFNbk0+frNbkD52k1tQ0vm7qr6VFr4kZBpyLiNTustXtFTHUnHqubYR6QWm2WpMo09uJG9/1esPA6S3G81Y21FbX1v/tp9fUGsebTx2p14wapdds1Mc6i4hI4p96jeXaEBFpthZ+rilqKDIOio8YBp3H1qzR24kZTq6IJJP6xRmL2w7e7zOOmFdwJwkADoQkADgQkgDgQEgCgAMhCQAOhCQAOBCSAOBASAKAAyEJAA7mGTf65g0is+e/ZWprqqFu2vF6O6Unm7qTZsOMm9oNtrZitrJWIUubUX0mjYjImpX6Fggpw2ttMs7gePkFvabWs7V1Ql9bXVuVuXrN5DxbW37DGzH6wm+qNYEVS0z9/ax6p1qTslwcIlKT4cSR6qhek288bwFLQvjDakksnjL1l0zqLzYUyjG15UsZ9h+xtNMtrQDAlxQhCQAOhCQAOBCSAOBASAKAAyEJAA6EJAA4EJIA4JDleZ5xODAAHH24kwQAB0ISABwISQBwMC9w8Zu8LLVm4TZbWysNNZMNNZdcZOtvda1eE7Q1JSHDnPnLVuz7Ne+P/Pp5azR2fv7/6DXjDDt31tbb+rvvN3rNX7vxN9rtfz3+/Wz93E0Za2w8qpec+aOvqzXrNq4ydVf7gr7X7cv6TqwiIvKsvlaGRNqcu0CWft6mD7T1ffn3j1NrssderdYsqQ6Y+osb9n8OZNtiqzmqt3XPbT9Ta7iTBAAHQhIAHAhJAHAgJAHAgZAEAAdCEgAcCEkAcDCPk4wZdsDKN7Z1uqGm0jKOyzb0Sqacoe9CFanfYWprg2VHtDaChj2LygpsbY06ZaRaUzC2Uq3JjdSa+vtFob5hW/11pqZMY2PbKx2n10SM10CDYaO3pjsWqDWhUlt/YUNdnvG7b/sbtrq9LM0OLbe1VTH2FLWm1qdvZJdM2HYzSyb1+7ZQtu1Nb4o2m+o03EkCgAMhCQAOhCQAOBCSAOBASAKAAyEJAA6EJAA4EJIA4EBIAoCDecZNfple07DW1pZlQebEx3pNywpbfxeeUKTWVEdtM25WG1eT3itiaNaXsrX16ovvqjXlEb2x3Bzb/40hfx
+15p6fxk1tPTb7U1NdW3Ne12vCxrYME8bkvQ/1mm8YakRECkr0mi1RW1v/NchWt9fs2/X3LbfQNuVmS0Bf6r6mQV+uPxWwXeThUFitiTdHTW2JGL+xFNxJAoADIQkADoQkADgQkgDgQEgCgAMhCQAOhCQAOBCSAOBgHkxe1aDXFA+xtZV4X6+xjNnOrrX19+jD76k1pfp488/6NGzH0FbEUBPaaWvrmb8Y+vuL/lqvv9TW38nj9FHMxRW2EzL6FH0gfHuGy+QL94q1sK77+rxueGb1uXlD1Zr6uO2Cb1ypDxSPpvRB2ym/bdJBS0yvC6b07SJERIL2eHPiThIAHAhJAHAgJAHAgZAEAAdCEgAcCEkAcCAkAcCBkAQAB/Noy0WGFban2MZ4mliGugZybW2tMYxKXm0cuVxrqLmvzechQ/3oYba+K5J6zYoP9Joq44ru/qC+DHd5xfGmtkIZDsIXERlpqMl8iPrh4/99zVY3/AR9pfG2GpuCak1BqW2Eem6OviVBNKXfa21o2Gjqr6kpqhcFbbEVS7SY6jTcSQKAAyEJAA6EJAA4EJIA4EBIAoADIQkADoQkADgQkgDgQEgCgIN5xs1aQ03BJltbeYaashK9JlRq66/QMH1n6Vu2tiwzaNrK7aXXBIxt+Q3vVtmxek1tja2/mKGmoWG9qa3AfszGuuF7hiJ95wAREbn/Cb3GsEOJTLFNMJICw2ywaVMGmNpqkkJbp59btWadWlMWrTe1dfF3f6jWhIZO1Y9pwyhTf39/eZFaU1tbbWor6TdeHAruJAHAgZAEAAdCEgAcCEkAcCAkAcCBkAQAB0ISABwISQBwMA8mt7AMPhYROXmMXhOPGBoybGcgIlJRqdesW2VrK7XLVrdXOGwoMv5XlTSsRp9vGMQcN46x3bJBr1n1nq2tvP56zdXtHjc26s/Jt8xMEJGLv6nXjBqVpdYUldsGRa+s1kfsv7zUNvuioNh4oX+uNF/fK8OfsA0m37jiT2pNoV8/Pn/wFFN/oVz9Ag7H8k1ttbB9AwAcfIQkADgQkgDgQEgCgAMhCQAOhCQAOBCSAOBASAKAAyEJAA7mGTe9DTWW5e9FRCzj4LMNy/1HDTMyRERWNes1PuOkBttY/338hr0ZgvoECRER2Wh4vUWlek2FcSuFZJVeU2trSjZsNRa28dw/9Bqf5cIUkbBhJlIq21NrNjS9Y+ovariewoZtRUREgonMTt7QPP3eZ/iE801tBUL6hiUvvrpMrXl22Wum/nKLT1BrcoK22GqOGr7xDbiTBAAHQhIAHAhJAHAgJAHAgZAEAAdCEgAcCEkAcCAkAcAhy/M8fQQtAByluJMEAAdCEgAcCEkAcDAvcJGVpW+3aXVaT70m37AIQyhs68+yFWvIsBCFiEiiSa95KL7v17zTDedt3EBb342GvnP09QjMb3q9YV0Fy86/IiJrDDU17X493p3X3ABDzXBD0fAKW38tcb2mvNjW1gmF+nk4/YF9+wRX36cvw1I4doKt86LJasmb6/RmnnzJsFqKiKT8Yb3G8g0tIinD1slPPvSwWsOdJAA4EJIA4EBIAoADIQkADoQkADgQkgDgQEgCgIN5nGR3+tcevearO/SaEwybO4mIZBvGDsZjtrYW7bLVtdYbaqIf29oKG2qaDWMbDUP4RERkeC+9xrfb1lalsc+2zjL0X2Xs33I3ULvJ0I7xO2ac4QUHDeP4RETi/sy+TTcm9Au+vto21rC2arXeVlwfl5mbb9v1zOfXB0hXr7ONuRTj5n4a7iQBwIGQBAAHQhIAHAhJAHAgJAHAgZAEAAdCEgAcCEkAcCAkAcDhkMy4sbCkd8LYVsowxSTebGtrrbHPvQwTh6TG2FZpN9XkGGayiIgUlBuKGmxt5eiTMjrIN/R/gvF9q/5IrzEs/C6vGNoREYmH9ZqiQltbjzyjT1Gbft++z698YL1aHwxvMfXtC+nT2saefLpak52TZ+ovIPoWAYFA0NRWY8R4cS
q4kwQAB0ISABwISQBwICQBwIGQBAAHQhIAHAhJAHAgJAHA4ZAMJh9oqLHsprBmo62/YsP2DdFPbG1lqqehptTYlr6wvUhOP72mWR+vKyIiq9bpNZFPbW21bNdrHmr3eJHh/c0O2/pPfUWv8SV6qzUDLHuBiEhTUH+3EsYtQ4pGG/cp+dzJ0/9LrWmO2aZiRGP6Ng/xFn2fhJZExNRfIKDP/LAOJm9JGPfHUHAnCQAOhCQAOBCSAOBASAKAAyEJAA6EJAA4EJIA4EBIAoDDIRlMblgoXMpL9JqGOlt/bxoGMhsXiZbjjHV7nWqoKTjG1pZlde9qw5jdqs22/qKGmoIBtrZC4f62wjYS+frZLhtVZGqrqKhArdm4UV+tuyVhuXpFAn59wHO2z7Zadzg/s3uZvIIctSZRbxtMXpCjn7e8PP11NDY2mvqLxfTzGwrZBvQXFtquDQ13kgDgQEgCgAMhCQAOhCQAOBCSAOBASAKAAyEJAA6EJAA4EJIA4HBIZtwYJsDI64bZNH2N/ekL0IuU9rC1dWqpsdPPFZyk17QYV5mP6qvkS06RPgWmImHZCEKkwTC5JJWy7QWRnW+bXdJWXnGxWtNs3IuiarU+FSmV0t+IhPHNys3VjyuvzLYtQ1OkyVS3V9XqBrUmmG3rOxDUX6/htJlnyUQi+qynSNR2PsJhfbaQBXeSAOBASAKAAyEJAA6EJAA4EJIA4EBIAoADIQkADoQkADgcksHk3WVHN7a1xngmAgWZbeAQLTxNraltsA2OTYq+5H5lUYVaMzTXNrC3KKkPw4812wZXb9ygD3Buz5fSt0AI5IRNbTXHY2pNYaE+eN16W5Gfr59jf0o/JhGRlkbbgPm9AgF9M5JQ2HYNxGL6Ma5ZU63WDB9abuovN1cf5N64xbYVROOWelOdhjtJAHAgJAHAgZAEAAdCEgAcCEkAcCAkAcCBkAQAB0ISABwISQBwOKJn3HSnnbttddWNtpkKe63baJi1ErPNWinI07dASPn0WSr1xhk+4tf/Dw1k27ZlqG/St09orzmpn5d4s63d7KB+XuIt+oymZNKwh4aIJBJ6XTiUY2or6cvsXsafbdhyQWyvwx/QjzEc0mcERY3vf1NUrwvl5ZvaKgiHTXUa7iQBwIGQBAAHQhIAHAhJAHAgJAHAgZAEAAdCEgAcCEkAcMjyPM871AcBAIcr7iQBwIGQBAAHQhIAHAhJAHAgJAHAgZAEAAdCEgAcCEkAcCAkAcDh/wN8Q0A6lk7y3QAAAABJRU5ErkJggg==\n"
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "5JoBNrpWI8oi"
      },
      "execution_count": 12,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Vision Transformer Model"
      ],
      "metadata": {
        "id": "ycVIqANlIIlV"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class EmbedLayer(nn.Module):\n",
        "    def __init__(self, image_dim):\n",
        "        super().__init__()\n",
        "        self.flatten = nn.Flatten(start_dim=2)\n",
        "        self.cls_token = nn.Parameter(torch.zeros(1, 1, image_dim), requires_grad=True)  # Cls Token\n",
        "        self.pos_embedding = nn.Parameter(torch.zeros(1, 1, image_dim, requires_grad=True))  # Positional Embedding\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = x.permute(0, 2, 3, 4, 1)\n",
        "        x = self.flatten(x) # Flattening Patches\n",
        "        x = torch.cat((torch.repeat_interleave(self.cls_token, x.shape[0], 0), x), dim=1)  # Adding classification token at the start of every sequence\n",
        "        x = x + self.pos_embedding  # Adding positional embedding\n",
        "        return x\n",
        "\n",
        "class EmbedLayerWithConv(nn.Module):\n",
        "    def __init__(self, image_size, embed_dim, n_channels=3, patch_size=8):\n",
        "        super().__init__()\n",
        "        self.conv1 = nn.Conv2d(n_channels, embed_dim, kernel_size=patch_size, stride=patch_size)  # Pixel Encoding\n",
        "        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim), requires_grad=True)  # Cls Token\n",
        "        self.pos_embedding = nn.Parameter(torch.zeros(1, (image_size // patch_size) ** 2 + 1, embed_dim), requires_grad=True)  # Positional Embedding\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.conv1(x)  # B C IH IW -> B E IH/P IW/P (Embedding the patches)\n",
        "        x = x.reshape([x.shape[0], x.shape[1], -1])  # B E IH/P IW/P -> B E S (Flattening the patches)\n",
        "        x = x.transpose(1, 2)  # B E S -> B S E\n",
        "        x = torch.cat((torch.repeat_interleave(self.cls_token, x.shape[0], 0), x), dim=1)  # Adding classification token at the start of every sequence\n",
        "        x = x + self.pos_embedding  # Adding positional embedding\n",
        "        return x\n"
      ],
      "metadata": {
        "id": "P1WYPPkQ-Y4A"
      },
      "execution_count": 31,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "class Classifier(nn.Module):\n",
        "    def __init__(self, embed_dim, mlp_size, n_classes):\n",
        "        super().__init__()\n",
        "        # Newer architectures skip fc1 and activations and directly apply fc2.\n",
        "        self.fc1 = nn.Linear(embed_dim, mlp_size)\n",
        "        self.activation = nn.ReLU()\n",
        "        self.fc2 = nn.Linear(mlp_size, n_classes)\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = x[:, 0, :]  # Get CLS token\n",
        "        x = self.fc1(x)\n",
        "        x = self.activation(x)\n",
        "        x = self.fc2(x)\n",
        "        return x"
      ],
      "metadata": {
        "id": "dHZMZKhRKZ-l"
      },
      "execution_count": 32,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "class VisionTransformer(nn.Module):\n",
        "    \"\"\"ViT: conv patch embedding -> linear projection -> Transformer encoder -> MLP head.\n",
        "\n",
        "    This cell previously defined VisionTransformer twice; the first (patch-tensor)\n",
        "    variant was silently shadowed by the second and also called EmbedLayer with the\n",
        "    wrong number of arguments, so only the live conv-embedding variant is kept.\n",
        "    \"\"\"\n",
        "    def __init__(self, image_dim, embed_dim, mlp_size, n_layers, n_attention_heads, n_classes):\n",
        "        super().__init__()\n",
        "        self.embedding = EmbedLayerWithConv(image_dim, embed_dim)  # image_dim is the image side length\n",
        "        self.linear = nn.Linear(embed_dim, embed_dim)\n",
        "        self.encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=n_attention_heads, batch_first=True)\n",
        "        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=n_layers, norm=nn.LayerNorm(embed_dim))\n",
        "        self.classifier = Classifier(embed_dim, mlp_size, n_classes)\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.embedding(x)  # B C H W -> B S E\n",
        "        x = self.linear(x)\n",
        "        x = self.transformer_encoder(x)\n",
        "        return self.classifier(x)  # logits read from the CLS token"
      ],
      "metadata": {
        "id": "ToGFrUoYKyXM"
      },
      "execution_count": 33,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Build a small ViT for a quick smoke test. embed_dim must be divisible\n",
        "# by n_attention_heads (144 / 12 = 12 dims per head).\n",
        "ViT = VisionTransformer(image_dim=32, embed_dim=144, mlp_size=3072,\n",
        "                        n_layers=4, n_attention_heads=12,\n",
        "                        n_classes=10)"
      ],
      "metadata": {
        "id": "l9asKUwQLaf0"
      },
      "execution_count": 46,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Forward one batch through the model. NOTE(review): `images` is assumed\n",
        "# to come from an earlier dataloader cell — confirm it is defined before\n",
        "# this cell on a fresh kernel run.\n",
        "out = ViT(images)"
      ],
      "metadata": {
        "id": "k6W9uI7iL6SA"
      },
      "execution_count": 47,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Sanity check: embedding yields (batch, n_tokens, embed_dim).\n",
        "# 17 tokens here — presumably 16 patches plus a CLS token; confirm\n",
        "# against the EmbedLayerWithConv definition.\n",
        "ViT.embedding(images).shape"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "VFp3sb1t8IO7",
        "outputId": "6dcb1e8a-19a2-4e1c-e7a8-00d2a6a96e06"
      },
      "execution_count": 48,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "torch.Size([128, 17, 144])"
            ]
          },
          "metadata": {},
          "execution_count": 48
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# The linear projection maps embed_dim -> embed_dim, so the token\n",
        "# tensor shape is unchanged.\n",
        "ViT.linear(ViT.embedding(images)).shape"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "f3w1VKwO6E8B",
        "outputId": "1b89f0d6-76a3-4f4e-f699-9db96004f421"
      },
      "execution_count": 49,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "torch.Size([128, 17, 144])"
            ]
          },
          "metadata": {},
          "execution_count": 49
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Final logits: one score per class for each image in the batch.\n",
        "out.shape"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "c6ke5Ex9L_fK",
        "outputId": "8b5f654f-fa44-4079-ad79-dd60bbe5ceb3"
      },
      "execution_count": 50,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "torch.Size([128, 10])"
            ]
          },
          "metadata": {},
          "execution_count": 50
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "class ViT_Trainer:\n",
        "  \"\"\"Train and evaluate a VisionTransformer on (images, labels) dataloaders.\n",
        "\n",
        "  Args:\n",
        "    train_dataset: iterable of (images, labels) batches used for training.\n",
        "    test_dataset: iterable of (images, labels) batches used for evaluation.\n",
        "    device: torch.device that the model and every batch are moved to.\n",
        "  \"\"\"\n",
        "  def __init__(self, train_dataset, test_dataset, device):\n",
        "    self.device = device\n",
        "    self.model = VisionTransformer(image_dim=32, embed_dim=288, mlp_size=3072,\n",
        "                                   n_layers=4, n_attention_heads=12,\n",
        "                                   n_classes=10)\n",
        "\n",
        "    total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)\n",
        "    print(f\"Total parameters in the model: {total_params/1e6}M\")\n",
        "    self.model = self.model.to(self.device)\n",
        "\n",
        "    self.optimizer = optim.AdamW(self.model.parameters(), lr=5e-4, weight_decay=1e-3)\n",
        "    # LR schedule: linear warmup for the first `warm_up_epochs`, then cosine\n",
        "    # decay down to eta_min. (`verbose=True` was removed — it is deprecated\n",
        "    # and only produced the UserWarning seen in earlier runs; use\n",
        "    # scheduler.get_last_lr() to inspect the learning rate.)\n",
        "    self.warm_up_epochs = 10\n",
        "    self.linear_warmup = optim.lr_scheduler.LinearLR(self.optimizer, start_factor=1/self.warm_up_epochs, end_factor=1.0, total_iters=self.warm_up_epochs)\n",
        "    self.cos_decay = optim.lr_scheduler.CosineAnnealingLR(optimizer=self.optimizer, T_max=50-self.warm_up_epochs, eta_min=1e-5)\n",
        "\n",
        "    self.ce = nn.CrossEntropyLoss()\n",
        "    self.train_dataset = train_dataset\n",
        "    self.test_dataset = test_dataset\n",
        "\n",
        "  def train(self, epochs):\n",
        "    \"\"\"Run `epochs` passes over the training data; print test accuracy per epoch.\"\"\"\n",
        "    for epoch in range(epochs):\n",
        "      self.model.train()\n",
        "      for images, labels in self.train_dataset:\n",
        "        images = images.to(self.device)\n",
        "        labels = labels.to(self.device)\n",
        "\n",
        "        self.optimizer.zero_grad()\n",
        "        outputs = self.model(images)\n",
        "        loss = self.ce(outputs, labels)\n",
        "        loss.backward()\n",
        "        self.optimizer.step()\n",
        "\n",
        "      # Step the schedulers once per epoch: warmup first, then cosine decay.\n",
        "      # NOTE(review): calling train() again continues from the current\n",
        "      # epoch-0 counter, so warmup restarts on a second call — confirm that\n",
        "      # is intended if resuming training.\n",
        "      if epoch < self.warm_up_epochs:\n",
        "        self.linear_warmup.step()\n",
        "      else:\n",
        "        self.cos_decay.step()\n",
        "\n",
        "      accuracy = self.test(self.test_dataset)\n",
        "      print(f\"Epoch {epoch+1}/{epochs}, Accuracy: {accuracy:.4f}\")\n",
        "\n",
        "  def test(self, loader):\n",
        "    \"\"\"Return classification accuracy of the current model over `loader`.\"\"\"\n",
        "    self.model.eval()\n",
        "    actual = []\n",
        "    pred = []\n",
        "    for x, y in loader:\n",
        "      x = x.to(self.device)\n",
        "      with torch.no_grad():\n",
        "        logits = self.model(x)\n",
        "      predicted = torch.max(logits, 1)[1]\n",
        "      actual += y.tolist()\n",
        "      pred += predicted.tolist()\n",
        "    return accuracy_score(y_true=actual, y_pred=pred)"
      ],
      "metadata": {
        "id": "n1r74enWM32T"
      },
      "execution_count": 54,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "trainer = ViT_Trainer(train_dataloader, test_dataloader, device)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "gzdyL6DqQD2c",
        "outputId": "c4f0fedf-3614-44ef-e23e-9dc2af6038aa"
      },
      "execution_count": 55,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Total parameters in the model: 8.643434M\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/torch/optim/lr_scheduler.py:28: UserWarning: The verbose parameter is deprecated. Please use get_last_lr() to access the learning rate.\n",
            "  warnings.warn(\"The verbose parameter is deprecated. Please use get_last_lr() \"\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "trainer.train(50)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "TMFJku6MqlcX",
        "outputId": "0ad0169d-f2ad-4426-c774-6772146190fa"
      },
      "execution_count": 56,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch 1/50, Accuracy: 0.4116\n",
            "Epoch 2/50, Accuracy: 0.4881\n",
            "Epoch 3/50, Accuracy: 0.5313\n",
            "Epoch 4/50, Accuracy: 0.5691\n",
            "Epoch 5/50, Accuracy: 0.5765\n",
            "Epoch 6/50, Accuracy: 0.6008\n",
            "Epoch 7/50, Accuracy: 0.6162\n",
            "Epoch 8/50, Accuracy: 0.6234\n",
            "Epoch 9/50, Accuracy: 0.6291\n",
            "Epoch 10/50, Accuracy: 0.6255\n",
            "Epoch 11/50, Accuracy: 0.6320\n",
            "Epoch 12/50, Accuracy: 0.6519\n",
            "Epoch 13/50, Accuracy: 0.6461\n",
            "Epoch 14/50, Accuracy: 0.6524\n",
            "Epoch 15/50, Accuracy: 0.6583\n",
            "Epoch 16/50, Accuracy: 0.6597\n",
            "Epoch 17/50, Accuracy: 0.6588\n",
            "Epoch 18/50, Accuracy: 0.6617\n",
            "Epoch 19/50, Accuracy: 0.6634\n",
            "Epoch 20/50, Accuracy: 0.6630\n",
            "Epoch 21/50, Accuracy: 0.6701\n",
            "Epoch 22/50, Accuracy: 0.6791\n",
            "Epoch 23/50, Accuracy: 0.6763\n",
            "Epoch 24/50, Accuracy: 0.6838\n",
            "Epoch 25/50, Accuracy: 0.6735\n",
            "Epoch 26/50, Accuracy: 0.6789\n",
            "Epoch 27/50, Accuracy: 0.6888\n",
            "Epoch 28/50, Accuracy: 0.6877\n",
            "Epoch 29/50, Accuracy: 0.6870\n",
            "Epoch 30/50, Accuracy: 0.6884\n",
            "Epoch 31/50, Accuracy: 0.6885\n",
            "Epoch 32/50, Accuracy: 0.6927\n",
            "Epoch 33/50, Accuracy: 0.6923\n",
            "Epoch 34/50, Accuracy: 0.6975\n",
            "Epoch 35/50, Accuracy: 0.6970\n",
            "Epoch 36/50, Accuracy: 0.6949\n",
            "Epoch 37/50, Accuracy: 0.7042\n",
            "Epoch 38/50, Accuracy: 0.7047\n",
            "Epoch 39/50, Accuracy: 0.7051\n",
            "Epoch 40/50, Accuracy: 0.7077\n",
            "Epoch 41/50, Accuracy: 0.7083\n",
            "Epoch 42/50, Accuracy: 0.7073\n",
            "Epoch 43/50, Accuracy: 0.7152\n",
            "Epoch 44/50, Accuracy: 0.7141\n",
            "Epoch 45/50, Accuracy: 0.7109\n",
            "Epoch 46/50, Accuracy: 0.7123\n",
            "Epoch 47/50, Accuracy: 0.7118\n",
            "Epoch 48/50, Accuracy: 0.7157\n",
            "Epoch 49/50, Accuracy: 0.7170\n",
            "Epoch 50/50, Accuracy: 0.7156\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "trainer.train(10)"
      ],
      "metadata": {
        "id": "oo2o5vEm5E_r",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "e3a49e3d-0c15-41c6-f21d-c14c88bdd142"
      },
      "execution_count": 57,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch 1/10, Accuracy: 0.7149\n",
            "Epoch 2/10, Accuracy: 0.7156\n",
            "Epoch 3/10, Accuracy: 0.7173\n",
            "Epoch 4/10, Accuracy: 0.7139\n",
            "Epoch 5/10, Accuracy: 0.7161\n",
            "Epoch 6/10, Accuracy: 0.7149\n",
            "Epoch 7/10, Accuracy: 0.7158\n",
            "Epoch 8/10, Accuracy: 0.7175\n",
            "Epoch 9/10, Accuracy: 0.7166\n",
            "Epoch 10/10, Accuracy: 0.7164\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "P9a70R0tDIWw"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}